{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BERT For Patents",
      "provenance": [
        {
          "file_id": "1d9KurXhXvrV-jo-x2f7DkZ40qx75YAh_",
          "timestamp": 1604694308174
        }
      ],
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//corp/legal/patents/colab:dst_colab_notebook",
        "kind": "shared"
      },
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ED6tBdZtOjlU"
      },
      "source": [
        "# BERT for Patents"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CqNm7ioGOgSm"
      },
      "source": [
        "Copyright 2020 Google Inc.\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
        "\n",
        "http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "c1vLcDJINTGg"
      },
      "source": [
        "import collections\n",
        "import math\n",
        "import random\n",
        "import sys\n",
        "import time\n",
        "from typing import Dict, List, Tuple\n",
        "\n",
        "# Use Tensorflow 2.0\n",
        "import tensorflow as tf\n",
        "import numpy as np"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "vfSIZaeaPHpZ"
      },
      "source": [
        "# Set BigQuery application credentials\n",
        "from google.cloud import bigquery\n",
        "import os\n",
        "os.environ[\"GOOGLE_APPLICATION_CREDENTIALS\"] = \"path/to/file.json\"\n",
        "\n",
        "project_id = \"your_bq_project_id\"\n",
        "bq_client = bigquery.Client(project=project_id)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7BojUHDYrESY"
      },
      "source": [
        "# You will have to clone the BERT repo\n",
        "!test -d bert_repo || git clone https://github.com/google-research/bert bert_repo\n",
        "if not 'bert_repo' in sys.path:\n",
        "  sys.path += ['bert_repo']"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QeoX7LfgPLGP"
      },
      "source": [
        "The BERT repo uses Tensorflow 1 and thus a few of the functions have been moved/changed/renamed in Tensorflow 2. In order for the BERT tokenizer to be used, one of the lines in the repo that was just cloned needs to be modified to comply with Tensorflow 2. Line 125 in the BERT tokenization.py file must be changed as follows:\n",
        "\n",
        "From => `with tf.gfile.GFile(vocab_file, \"r\") as reader:`\n",
        "\n",
        "To => `with tf.io.gfile.GFile(vocab_file, \"r\") as reader:`\n",
        "\n",
        "Once that is complete and the file is saved, the tokenization library can be imported."
      ]
    },
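    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The one-line change above can also be applied programmatically. The following is a minimal sketch, not part of the BERT repo: `patch_gfile_call` is a hypothetical helper name, and the path `bert_repo/tokenization.py` assumes the clone location used earlier in this notebook. It uses a plain string replacement rather than a line number, so it would also rewrite any other `tf.gfile.GFile` call in the file.\n",
        "\n",
        "```python\n",
        "from pathlib import Path\n",
        "\n",
        "\n",
        "def patch_gfile_call(path):\n",
        "  \"\"\"Rewrites TF1-style `tf.gfile.GFile` calls; returns True if the file changed.\"\"\"\n",
        "  source = Path(path).read_text()\n",
        "  # `tf.io.gfile.GFile` does not contain `tf.gfile.GFile` as a substring,\n",
        "  # so an already patched file is left untouched.\n",
        "  patched = source.replace('tf.gfile.GFile', 'tf.io.gfile.GFile')\n",
        "  if patched != source:\n",
        "    Path(path).write_text(patched)\n",
        "  return patched != source\n",
        "\n",
        "\n",
        "# Assumed clone location; adjust if the repo was cloned elsewhere.\n",
        "target = Path('bert_repo/tokenization.py')\n",
        "if target.exists():\n",
        "  print('patched' if patch_gfile_call(target) else 'already up to date')\n",
        "```\n",
        "\n",
        "Running the helper a second time leaves the file unchanged, so the cell is safe to re-execute."
      ]
    },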
    {
      "cell_type": "code",
      "metadata": {
        "id": "HsSJXKPDPLXn"
      },
      "source": [
        "import tokenization"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mXAmqo36RJXk"
      },
      "source": [
        "## Set Up BERT and Some Helpers\n",
        "\n",
        "The BERT exported here has been trained on >100 million patent documents and was trained on all parts of a patent (abstract, claims, description).\n",
        "\n",
        "The BERT model exported here comes in two formats:\n",
        "\n",
        "* [SavedModel](https://storage.googleapis.com/patents-public-data-github/saved_model.zip)\n",
        "\n",
        "* [Checkpoint](https://storage.googleapis.com/patents-public-data-github/checkpoint.zip)\n",
        "\n",
        "**NOTE: This notebook uses the saved model format.**\n",
        "\n",
        "The models can also be loaded and saved in another format or just the weights can be saved.\n",
        "\n",
        "The BERT model has been trained on >100 million patent documents and was trained on all parts of a patent (abstract, claims, description). It has a similar configuration to the BERT-Large model, with a couple of important notes:\n",
        "\n",
        "* The maximum input sequence length is 512 tokens and maximum masked words for a sequence is 45.\n",
        "* The vocabulary has approximately 9000 added words from the standard BERT vocabulary. These represent frequently used patent terms.\n",
        "* The vocabulary includes \"context\" tokens indicating what part of a patent the text is from (abstract, claims, summary, invention). Providing context tokens in the examples is optional.\n",
        "\n",
        "The full BERT vocabulary can be downloaded [here](https://storage.googleapis.com/patents-public-data-github/bert_for_patents_vocab_39k.txt). The vocabulary also contains 1000 unused tokens so that more tokens can be added.\n",
        "\n",
        "The exact configuration for the BERT model is as follows (and downloaded [here](https://storage.googleapis.com/patents-public-data-github/bert_for_patents_large_config.json)):\n",
        "\n",
        "* attention_probs_dropout_prob: 0.1\n",
        "* hidden_act: gelu\n",
        "* hidden_dropout_prob: 0.1\n",
        "* hidden_size: 1024\n",
        "* initializer_range: 0.02\n",
        "* intermediate_size: 4096\n",
        "* max_position_embeddings: 512\n",
        "* num_attention_heads: 16\n",
        "* num_hidden_layers: 24\n",
        "* vocab_size: 39859\n",
        "\n",
        "The model has requires the following input signatures:\n",
        "1. `input_ids`\n",
        "2. `input_mask`\n",
        "3. `segment_ids`\n",
        "4. `mlm_ids`\n",
        "\n",
        "And the BERT model contains output signatures for:\n",
        "1. `cls_token`\n",
        "2. `encoder_layer` is the contextualized word embeddings from the last encoder layer.\n",
        "3. `mlm_logits` is the predictions for any masked tokens provided to the model."
      ]
    },
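    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the fixed input shapes concrete, here is a minimal sketch of how one example could be padded before it is sent to the model. `build_inputs` is a hypothetical helper and the token IDs are made up for illustration. The zero-padding mirrors the helper code used in this notebook, and the two limits are the ones stated above.\n",
        "\n",
        "```python\n",
        "MAX_SEQ_LENGTH = 512\n",
        "MAX_PREDS_PER_SEQ = 45\n",
        "\n",
        "\n",
        "def build_inputs(token_ids, mask_positions):\n",
        "  \"\"\"Pads one single-segment example to the fixed model input shapes.\"\"\"\n",
        "  if len(token_ids) > MAX_SEQ_LENGTH:\n",
        "    raise ValueError('too many tokens')\n",
        "  if len(mask_positions) > MAX_PREDS_PER_SEQ:\n",
        "    raise ValueError('too many masked positions')\n",
        "  pad = MAX_SEQ_LENGTH - len(token_ids)\n",
        "  return {\n",
        "      'input_ids': token_ids + [0] * pad,\n",
        "      # 1 marks a real token, 0 marks padding.\n",
        "      'input_mask': [1] * len(token_ids) + [0] * pad,\n",
        "      # A single text segment uses segment id 0 throughout.\n",
        "      'segment_ids': [0] * MAX_SEQ_LENGTH,\n",
        "      'mlm_ids': mask_positions + [0] * (MAX_PREDS_PER_SEQ - len(mask_positions)),\n",
        "  }\n",
        "\n",
        "\n",
        "# Hypothetical token IDs, for illustration only.\n",
        "example = build_inputs([2, 101, 102, 103, 3], mask_positions=[2])\n",
        "print({name: len(values) for name, values in example.items()})\n",
        "```\n",
        "\n",
        "Every sequence-level input ends up with 512 entries and `mlm_ids` with 45, regardless of how long the original text was."
      ]
    },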
    {
      "cell_type": "code",
      "metadata": {
        "id": "BgGZ4mNcQ1NJ"
      },
      "source": [
        "# The functions in this block are also found in the bert cloned repo in the \n",
        "# `run_classifier.py` file, however those also have some compatibility issues \n",
        "# and thus the functions needed are just copied here.\n",
        "\n",
        "class InputFeatures(object):\n",
        "  \"\"\"A single set of features of data.\"\"\"\n",
        "\n",
        "  def __init__(self, input_ids, input_mask, segment_ids, label_id,\n",
        "               is_real_example=True):\n",
        "    self.input_ids = input_ids\n",
        "    self.input_mask = input_mask\n",
        "    self.segment_ids = segment_ids\n",
        "    self.label_id = label_id\n",
        "    self.is_real_example = is_real_example\n",
        "\n",
        "class InputExample(object):\n",
        "  \"\"\"A single training/test example for simple sequence classification.\"\"\"\n",
        "\n",
        "  def __init__(self, guid, text_a, text_b=None, label=None):\n",
        "    \"\"\"Constructs a InputExample.\"\"\"\n",
        "    self.guid = guid\n",
        "    self.text_a = text_a\n",
        "    self.text_b = text_b\n",
        "    self.label = label\n",
        "\n",
        "def _truncate_seq_pair(tokens_a, tokens_b, max_length):\n",
        "  \"\"\"Truncates a sequence pair in place to the maximum length.\"\"\"\n",
        "  while True:\n",
        "    total_length = len(tokens_a) + len(tokens_b)\n",
        "    if total_length <= max_length:\n",
        "      break\n",
        "    if len(tokens_a) > len(tokens_b):\n",
        "      tokens_a.pop()\n",
        "    else:\n",
        "      tokens_b.pop()\n",
        "\n",
        "def convert_examples_to_features(examples, label_list, max_seq_length,\n",
        "                                 tokenizer):\n",
        "  \"\"\"Convert a set of `InputExample`s to a list of `InputFeatures`.\"\"\"\n",
        "\n",
        "  features = []\n",
        "  for (ex_index, example) in enumerate(examples):\n",
        "    feature = convert_single_example(ex_index, example, label_list,\n",
        "                                     max_seq_length, tokenizer)\n",
        "    features.append(feature)\n",
        "  return features\n",
        "\n",
        "def convert_single_example(ex_index, example, label_list, max_seq_length,\n",
        "                           tokenizer):\n",
        "  \"\"\"Converts a single `InputExample` into a single `InputFeatures`.\"\"\"\n",
        "\n",
        "  label_map = {}\n",
        "  for (i, label) in enumerate(label_list):\n",
        "    label_map[label] = i\n",
        "\n",
        "  tokens_a = tokenizer.tokenize(example.text_a)\n",
        "  tokens_b = None\n",
        "  if example.text_b:\n",
        "    tokens_b = tokenizer.tokenize(example.text_b)\n",
        "\n",
        "  if tokens_b:\n",
        "    # Modifies `tokens_a` and `tokens_b` in place so that the total\n",
        "    # length is less than the specified length.\n",
        "    # Account for [CLS], [SEP], [SEP] with \"- 3\"\n",
        "    _truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)\n",
        "  else:\n",
        "    # Account for [CLS] and [SEP] with \"- 2\"\n",
        "    if len(tokens_a) > max_seq_length - 2:\n",
        "      tokens_a = tokens_a[0:(max_seq_length - 2)]\n",
        "\n",
        "  tokens = []\n",
        "  segment_ids = []\n",
        "  tokens.append(\"[CLS]\")\n",
        "  segment_ids.append(0)\n",
        "  for token in tokens_a:\n",
        "    tokens.append(token)\n",
        "    segment_ids.append(0)\n",
        "  tokens.append(\"[SEP]\")\n",
        "  segment_ids.append(0)\n",
        "\n",
        "  if tokens_b:\n",
        "    for token in tokens_b:\n",
        "      tokens.append(token)\n",
        "      segment_ids.append(1)\n",
        "    tokens.append(\"[SEP]\")\n",
        "    segment_ids.append(1)\n",
        "\n",
        "  input_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
        "\n",
        "  # The mask has 1 for real tokens and 0 for padding tokens. Only real\n",
        "  # tokens are attended to.\n",
        "  input_mask = [1] * len(input_ids)\n",
        "\n",
        "  # Zero-pad up to the sequence length.\n",
        "  while len(input_ids) < max_seq_length:\n",
        "    input_ids.append(0)\n",
        "    input_mask.append(0)\n",
        "    segment_ids.append(0)\n",
        "\n",
        "  assert len(input_ids) == max_seq_length\n",
        "  assert len(input_mask) == max_seq_length\n",
        "  assert len(segment_ids) == max_seq_length\n",
        "\n",
        "  label_id = label_map[example.label]\n",
        "\n",
        "  feature = InputFeatures(\n",
        "      input_ids=input_ids,\n",
        "      input_mask=input_mask,\n",
        "      segment_ids=segment_ids,\n",
        "      label_id=label_id,\n",
        "      is_real_example=True)\n",
        "  return feature"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eZDhP_S1IA9Z"
      },
      "source": [
        "def get_tokenized_input(\n",
        "    texts: List[str], tokenizer: tokenization.FullTokenizer) -> List[List[int]]:\n",
        "  \"\"\"Returns list of tokenized text segments.\"\"\"\n",
        "\n",
        "  return [tokenizer.tokenize(text) for text in texts]\n",
        "\n",
        "\n",
        "class BertPredictor():\n",
        "\n",
        "  def __init__(\n",
        "      self, \n",
        "      model_name: str, \n",
        "      text_tokenizer: tokenization.FullTokenizer, \n",
        "      max_seq_length: int,\n",
        "      max_preds_per_seq: int,\n",
        "      has_context: bool = False):\n",
        "    \"\"\"Initializes a BertPredictor object.\"\"\"\n",
        "\n",
        "    self.tokenizer = text_tokenizer\n",
        "    self.max_seq_length = max_seq_length\n",
        "    self.max_preds_per_seq = max_preds_per_seq\n",
        "    self.mask_token_id = 4\n",
        "    # If you want to add context tokens to the input, set value to True.\n",
        "    self.context = has_context\n",
        "\n",
        "    model = tf.compat.v2.saved_model.load(export_dir=model_name, tags=['serve'])\n",
        "    self.model = model.signatures['serving_default']\n",
        "\n",
        "  def get_features_from_texts(self, texts: List[str]) -> Dict[str, int]:\n",
        "    \"\"\"Uses tokenizer to convert raw text into features for prediction.\"\"\"\n",
        "\n",
        "    #examples = [run_classifier.InputExample(0, t, label='') for t in texts]\n",
        "    #features = run_classifier.convert_examples_to_features(\n",
        "    #    examples, [''], self.max_seq_length, self.tokenizer)\n",
        "    examples = [InputExample(0, t, label='') for t in texts]\n",
        "    features = convert_examples_to_features(\n",
        "        examples, [''], self.max_seq_length, self.tokenizer)\n",
        "    return dict(\n",
        "        input_ids=[f.input_ids for f in features],\n",
        "        input_mask=[f.input_mask for f in features],\n",
        "        segment_ids=[f.segment_ids for f in features]\n",
        "    )\n",
        "\n",
        "  def insert_token(self, input: List[int], token: int) -> List[int]:\n",
        "    \"\"\"Adds token to input.\"\"\"\n",
        "\n",
        "    return input[:1] + [token] + input[1:-1]\n",
        "\n",
        "  def add_input_context(\n",
        "      self, inputs: Dict[str, int], context_tokens: List[str]\n",
        "  ) -> Dict[str, int]:\n",
        "    \"\"\"Adds context token to input features.\"\"\"\n",
        "\n",
        "    context_token_ids = self.tokenizer.convert_tokens_to_ids(context_tokens)\n",
        "    segment_token_id = 0\n",
        "    mask_token_id = 1\n",
        "\n",
        "    for i, context_token_id in enumerate(context_token_ids):\n",
        "      inputs['input_ids'][i] = self.insert_token(\n",
        "          inputs['input_ids'][i], context_token_id)\n",
        "\n",
        "      inputs['segment_ids'][i] = self.insert_token(\n",
        "          inputs['segment_ids'][i], segment_token_id)\n",
        "\n",
        "      inputs['input_mask'][i] = self.insert_token(\n",
        "          inputs['input_mask'][i], mask_token_id)\n",
        "    return inputs\n",
        "\n",
        "  def create_mlm_mask(\n",
        "      self, inputs: Dict[str, int], mlm_ids: List[List[int]]\n",
        "  ) -> Tuple[Dict[str, List[List[int]]], List[List[str]]]:\n",
        "    \"\"\"Creates masked language model mask.\"\"\"\n",
        "\n",
        "    masked_text_tokens = []\n",
        "    mlm_positions = []\n",
        "\n",
        "    if not mlm_ids:\n",
        "      inputs['mlm_ids'] = mlm_positions\n",
        "      return inputs, masked_text_tokens\n",
        "\n",
        "    for i, _ in enumerate(mlm_ids):\n",
        "\n",
        "      masked_text = []\n",
        "\n",
        "      # Pad mlm positions to max seqeuence length.\n",
        "      mlm_positions.append(\n",
        "          mlm_ids[i] + [0] * (self.max_preds_per_seq - len(mlm_ids[i])))\n",
        "\n",
        "      for pos in mlm_ids[i]:\n",
        "        # Retrieve the masked token.\n",
        "        masked_text.extend(\n",
        "            self.tokenizer.convert_ids_to_tokens([inputs['input_ids'][i][pos]]))\n",
        "        # Replace the mask positions with the mask token.\n",
        "        inputs['input_ids'][i][pos] = self.mask_token_id\n",
        "  \n",
        "      masked_text_tokens.append(masked_text)\n",
        "\n",
        "    inputs['mlm_ids'] = mlm_positions\n",
        "    return inputs, masked_text_tokens\n",
        "\n",
        "  def predict(\n",
        "      self, texts: List[str], mlm_ids: List[List[int]] = None, \n",
        "      context_tokens: List[str] = None\n",
        "  ) -> Tuple[Dict[str, tf.Tensor], Dict[str, List[List[int]]], List[List[str]]]:\n",
        "    \"\"\"Gets BERT predictions for provided text and masks.\n",
        "    \n",
        "    Args:\n",
        "      texts: List of texts to get BERT predictions.\n",
        "      mlm_ids: List of lists corresponding to the mask positions for each input\n",
        "        in `texts`.\n",
        "      context_token: List of string contexts to prepend to input texts.\n",
        "\n",
        "    Returns:\n",
        "      response: BERT model response.\n",
        "      inputs: Tokenized and modified input to BERT model.\n",
        "      masked_text: Raw strings of the masked tokens.\n",
        "    \"\"\"\n",
        "\n",
        "    if mlm_ids:\n",
        "      assert len(mlm_ids) == len(texts), ('If mask ids provided, they must be '\n",
        "          'equal to the length of the input text.')\n",
        "\n",
        "    if self.context:\n",
        "      # If model uses context, but none provided, use 'UNK' token for context.\n",
        "      if not context_tokens:\n",
        "        context_tokens = ['[UNK]' for _ in range(len(texts))]\n",
        "      assert len(context_tokens) == len(texts), ('If context tokens provided, '\n",
        "          'they must be equal to the length of the input text.')\n",
        "    \n",
        "    inputs = self.get_features_from_texts(texts)\n",
        "\n",
        "    # If using a BERT model with context, add corresponding tokens.\n",
        "    if self.context:\n",
        "      inputs = self.add_input_context(inputs, context_tokens)\n",
        "\n",
        "    inputs, masked_text = self.create_mlm_mask(inputs, mlm_ids)\n",
        "\n",
        "    response = self.model(\n",
        "      segment_ids=tf.convert_to_tensor(inputs['segment_ids'], dtype=tf.int64),\n",
        "      input_mask=tf.convert_to_tensor(inputs['input_mask'], dtype=tf.int64),\n",
        "      input_ids=tf.convert_to_tensor(inputs['input_ids'], dtype=tf.int64),\n",
        "      mlm_positions=tf.convert_to_tensor(inputs['mlm_ids'], dtype=tf.int64),\n",
        "      )\n",
        "    \n",
        "    if mlm_ids:\n",
        "      # Do a reshape of the mlm logits (batch size, num predictions, vocab).\n",
        "      new_shape = (len(texts), self.max_preds_per_seq, -1)\n",
        "      response['mlm_logits'] = tf.reshape(\n",
        "          response['mlm_logits'], shape=new_shape)\n",
        "    \n",
        "    return response, inputs, masked_text \n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "w8iPAdupddm4"
      },
      "source": [
        "# Some helper functions.\n",
        "\n",
        "def get_mlm_ids_by_token(\n",
        "    mask_token: str, tokenized_text: List[List[str]], \n",
        "    has_context: bool = False, first_occurence: bool = True\n",
        ") -> List[List[int]]:\n",
        "  \"\"\"Returns position ids for masking a specified token.\"\"\"\n",
        "\n",
        "  pos_add = 2 if has_context else 1\n",
        "  mlm_ids = []\n",
        "  for i, tokens in enumerate(tokenized_text):\n",
        "    pub_mlm_ids = []\n",
        "    for j, token in enumerate(tokens):\n",
        "      if token == mask_token:\n",
        "        pub_mlm_ids.append(j + pos_add)\n",
        "        if first_occurence:\n",
        "          break\n",
        "    mlm_ids.append(pub_mlm_ids)\n",
        "\n",
        "  return mlm_ids\n",
        "\n",
        "\n",
        "def bert_topk_predictions(\n",
        "    mlm_logits: tf.Tensor, mlm_ids: List[List[int]], top_k: int = 5\n",
        ") -> Tuple[List[int], List[str]]:\n",
        "  \"\"\"Returns BERT predicted token ids and terms for masked ids.\n",
        "  \n",
        "  Args:\n",
        "    mlm_logits: The BERT masked language logits.\n",
        "    mlm_ids: The masked ids.\n",
        "    top_k: Number of predictions to return for each mask.\n",
        "\n",
        "  Returns:\n",
        "    token_preds: Token predictions for each mask position.\n",
        "    term_preds: Term predictions for each mask position.\n",
        "  \"\"\"\n",
        "\n",
        "  token_preds = []\n",
        "  term_preds = []\n",
        "\n",
        "  # Tradeoff between single call for all (including non masked) and then gather\n",
        "  # vs. calling math top_k over and over\n",
        "\n",
        "  for i, ids in enumerate(mlm_ids):\n",
        "    current_token_preds = []\n",
        "    current_term_preds = []\n",
        "    for j, id in enumerate(ids):\n",
        "      preds = tf.math.top_k(mlm_logits[i][j], top_k).indices.numpy().tolist()\n",
        "      current_token_preds.append(preds)\n",
        "      current_term_preds.append(tokenizer.convert_ids_to_tokens(preds))\n",
        "    token_preds.append(current_token_preds)\n",
        "    term_preds.append(current_term_preds)\n",
        "\n",
        "  return token_preds, term_preds\n",
        "\n",
        "\n",
        "def find_rankings(\n",
        "    words: List[str], word_ids: List[int], mlm_logits: tf.Tensor, \n",
        "    mlm_ids: List[List[str]]\n",
        ") -> Dict[str, float]:\n",
        "  \"\"\"Return the rankings in the bert predictions for the provided words.\"\"\"\n",
        "  \n",
        "  word_positions = []\n",
        "\n",
        "  # Iterate through all predictions.\n",
        "  for i, _ in enumerate(mlm_ids):\n",
        "    for j, _ in enumerate(mlm_ids[i]):\n",
        "      logits = tf.argsort(mlm_logits[i][j], direction='DESCENDING')\n",
        "      positions = tf.reshape(tf.where(tf.equal(\n",
        "          tf.expand_dims(word_ids, axis=-1), logits))[:,-1], [1, -1])\n",
        "      word_positions.extend(list(positions.numpy()))\n",
        "\n",
        "  transposed = np.array(word_positions).T\n",
        "  word_dict = dict()\n",
        "\n",
        "  for i, word in enumerate(words):\n",
        "    total = sum(transposed[i])\n",
        "    word_dict[word] = {\n",
        "        'average': transposed[i].mean(),\n",
        "        'max': transposed[i].max(),\n",
        "        'min': transposed[i].min(),\n",
        "        'std': transposed[i].std(),\n",
        "    }\n",
        "\n",
        "  return word_dict"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JBqRRfigQxxK"
      },
      "source": [
        "# Load BERT"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "kp2fx508lWBG"
      },
      "source": [
        "MAX_SEQ_LENGTH = 512\n",
        "MAX_PREDS_PER_SEQUENCE = 45\n",
        "MODEL_DIR = 'path/to/bert/model/'\n",
        "VOCAB = 'path/to/vocab.txt'\n",
        "\n",
        "tokenizer = tokenization.FullTokenizer(VOCAB, do_lower_case=True)\n",
        "\n",
        "bert_predictor = BertPredictor(\n",
        "    model_name=MODEL_DIR,\n",
        "    text_tokenizer=tokenizer,\n",
        "    max_seq_length=MAX_SEQ_LENGTH,\n",
        "    max_preds_per_seq=MAX_PREDS_PER_SEQUENCE,\n",
        "    has_context=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rtzZg5LESCxF"
      },
      "source": [
        "## Masked Term Example from Patent Abstracts\n",
        "\n",
        "Here we do a simple query from the BigQuery patents data to collect the abstract for 3 different patent abstracts that use the word \"eye\" and print our their predictions to see how the synonyms change for the same word as the patent changes."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u3iTTJQ5SFba"
      },
      "source": [
        "test_pubs = ('US-8000000-B2', 'US-2007186831-A1', 'US-2009030261-A1')\n",
        "\n",
        "query = r\"\"\"\n",
        "  SELECT publication_number, abstract, url\n",
        "  FROM `patents-public-data.google_patents_research.publications` \n",
        "  WHERE publication_number in {}\n",
        "\"\"\".format(test_pubs)\n",
        "\n",
        "df = bq_client.query(query).to_dataframe()"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Rg3R5d_OBEUe"
      },
      "source": [
        "tokenized_inputs = get_tokenized_input(df.abstract.to_list(), tokenizer)\n",
        "mlm_ids = get_mlm_ids_by_token('eye', tokenized_inputs)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lgyFtMw4Bj4k"
      },
      "source": [
        "response, inputs, masked_text = bert_predictor.predict(\n",
        "    df.abstract.to_list(), mlm_ids)\n",
        "\n",
        "token_preds, term_preds = bert_topk_predictions(response['mlm_logits'], mlm_ids)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Q4l2iAH72p97",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1605127862594,
          "user_tz": 300,
          "elapsed": 340,
          "user": {
            "displayName": "Rob Srebrovic",
            "photoUrl": "",
            "userId": "06004353344935214283"
          }
        },
        "outputId": "c8080b64-f227-4117-d335-7336ccaedfe5"
      },
      "source": [
        "for row, terms in zip(df.values.tolist(), term_preds):\n",
        "  out = 'Patent: {}. ({})\\nAbstract: {}\\nPredictions of term eye \\n\\t{}\\n'\n",
        "  print(out.format(row[0], row[2], row[1][:100]+'...', terms))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Patent: US-2007186831-A1. (https://patents.google.com/patent/US20070186831A1)\n",
            "Abstract: A sewing machine includes a thread take-up, a thread take-up driving mechanism driving the thread ta...\n",
            "Predictions of term eye \n",
            "\t[['hole', 'point', 'drop', 'eye', 'tip']]\n",
            "\n",
            "Patent: US-8000000-B2. (https://patents.google.com/patent/US8000000B2)\n",
            "Abstract: A visual prosthesis apparatus and a method for limiting power consumption in a visual prosthesis app...\n",
            "Predictions of term eye \n",
            "\t[['eye', 'retina', 'eyes', 'brain', 'eyeball']]\n",
            "\n",
            "Patent: US-2009030261-A1. (https://patents.google.com/patent/US20090030261A1)\n",
            "Abstract: Currently, no efficient, non-invasive methods exist for delivering drugs and/or other therapeutic ag...\n",
            "Predictions of term eye \n",
            "\t[['eye', 'eyeball', 'eyes', 'body', 'cornea']]\n",
            "\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9_GN2_v1Re0i"
      },
      "source": [
        "## Generating Synonyms for a CPC \n",
        "\n",
        "Building on the above we can query for patents containing certain terms across CPC codes and examine how the predicted synonyms change in each of those CPC codes."
      ]
    },
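    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The word positions reported in this section are zero-based ranks in the model's predictions sorted by descending logit, so a position of 0 means the word was the top prediction for the masked token. The following is a minimal pure-Python sketch of that idea, independent of TensorFlow. `rank_of_ids` and the toy logits are hypothetical and only illustrate the ranking step.\n",
        "\n",
        "```python\n",
        "def rank_of_ids(logits, ids):\n",
        "  \"\"\"Returns the zero-based rank of each id when logits are sorted descending.\"\"\"\n",
        "  order = sorted(range(len(logits)), key=lambda i: -logits[i])\n",
        "  rank_by_id = {token_id: rank for rank, token_id in enumerate(order)}\n",
        "  return [rank_by_id[token_id] for token_id in ids]\n",
        "\n",
        "\n",
        "# Toy vocabulary of four token ids; id 1 has the highest logit.\n",
        "toy_logits = [0.1, 2.5, -1.0, 1.7]\n",
        "print(rank_of_ids(toy_logits, [1, 3, 2]))  # [0, 1, 3]\n",
        "```\n",
        "\n",
        "Averaging these ranks over many abstracts gives the per-word statistics printed in the cells of this section."
      ]
    },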
    {
      "cell_type": "code",
      "metadata": {
        "id": "h9mVoGKkTiiA"
      },
      "source": [
        "search_token = 'priming'\n",
        "words = ['priming', 'cleaning', 'maintenance',  'bonding', 'subbing', 'anchor']\n",
        "word_ids = tokenizer.convert_tokens_to_ids(words)\n",
        "\n",
        "query = r\"\"\"\n",
        "  SELECT publication_number, abstract, url\n",
        "  FROM `patents-public-data.google_patents_research.publications`,\n",
        "    UNNEST(cpc) as cpc\n",
        "  WHERE \n",
        "    cpc.code = '{}' AND\n",
        "    cpc.first = True AND\n",
        "    abstract like '% {} %'\n",
        "  LIMIT 100\n",
        "\"\"\""
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4_xI-NeOWRmQ",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1605026134962,
          "user_tz": 300,
          "elapsed": 11931,
          "user": {
            "displayName": "Rob Srebrovic",
            "photoUrl": "",
            "userId": "06004353344935214283"
          }
        },
        "outputId": "0eb872eb-c3f2-45df-bc91-0553ed54009b"
      },
      "source": [
        "cpc = 'B41J2/165'\n",
        "\n",
        "df = bq_client.query(query.format(cpc, search_token)).to_dataframe()\n",
        "\n",
        "tokenized_inputs = get_tokenized_input(df.abstract.to_list(), tokenizer)\n",
        "mlm_ids = get_mlm_ids_by_token('priming', tokenized_inputs)\n",
        "\n",
        "response, inputs, masked_text = bert_predictor.predict(\n",
        "    df.abstract.to_list(), mlm_ids)\n",
        "\n",
        "token_preds, term_preds = bert_topk_predictions(\n",
        "    response['mlm_logits'], mlm_ids, top_k=10)\n",
        "\n",
        "word_dict = find_rankings(words, word_ids, response['mlm_logits'], mlm_ids)\n",
        "\n",
        "print('Word positions for our term list:')\n",
        "for k, v in word_dict.items():\n",
        "  print(k, v)\n",
        "\n",
        "prediction_list = [x[0] for x in term_preds]\n",
        "all_predictions = [item for sublist in prediction_list for item in sublist]\n",
        "\n",
        "all_counts = collections.Counter(all_predictions)\n",
        "top_10 = collections.Counter(all_predictions).most_common(10)\n",
        "\n",
        "print('\\nMost common words predicted:')\n",
        "for t, _ in top_10:\n",
        "  print(t)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Word positions for our term list:\n",
            "priming {'average': 0.75, 'max': 6, 'min': 0, 'std': 1.984313483298443}\n",
            "cleaning {'average': 2.625, 'max': 3, 'min': 0, 'std': 0.9921567416492215}\n",
            "maintenance {'average': 22.875, 'max': 57, 'min': 3, 'std': 20.55138377336183}\n",
            "bonding {'average': 133.75, 'max': 260, 'min': 69, 'std': 69.24729236583912}\n",
            "subbing {'average': 1577.5, 'max': 1996, 'min': 1309, 'std': 242.1414875646055}\n",
            "anchor {'average': 4977.875, 'max': 15669, 'min': 1834, 'std': 4156.430633292825}\n",
            "\n",
            "Most common words predicted:\n",
            "cleaning\n",
            "capping\n",
            "priming\n",
            "sealing\n",
            "filling\n",
            "pumping\n",
            "flushing\n",
            "purging\n",
            "servicing\n",
            "maintenance\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "tfWNs9nOWLYB",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1605026189582,
          "user_tz": 300,
          "elapsed": 54612,
          "user": {
            "displayName": "Rob Srebrovic",
            "photoUrl": "",
            "userId": "06004353344935214283"
          }
        },
        "outputId": "2dc56674-c618-4f2d-9946-dbeb75bb6698"
      },
      "source": [
        "cpc = 'F04D9/041'\n",
        "\n",
        "df = bq_client.query(query.format(cpc, search_token)).to_dataframe()\n",
        "\n",
        "tokenized_inputs = get_tokenized_input(df.abstract.to_list(), tokenizer)\n",
        "mlm_ids = get_mlm_ids_by_token('priming', tokenized_inputs)\n",
        "\n",
        "response, inputs, masked_text = bert_predictor.predict(\n",
        "    df.abstract.to_list(), mlm_ids)\n",
        "\n",
        "token_preds, term_preds = bert_topk_predictions(\n",
        "    response['mlm_logits'], mlm_ids, top_k=10)\n",
        "\n",
        "word_dict = find_rankings(words, word_ids, response['mlm_logits'], mlm_ids)\n",
        "\n",
        "print('Word positions for our term list:')\n",
        "for k, v in word_dict.items():\n",
        "  print(k, v)\n",
        "\n",
        "prediction_list = [x[0] for x in term_preds]\n",
        "all_predictions = [item for sublist in prediction_list for item in sublist]\n",
        "\n",
        "all_counts = collections.Counter(all_predictions)\n",
        "top_10 = collections.Counter(all_predictions).most_common(10)\n",
        "\n",
        "print('\\nMost common words predicted:')\n",
        "for t, _ in top_10:\n",
        "  print(t)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Word positions for our term list:\n",
            "priming {'average': 7.7560975609756095, 'max': 109, 'min': 0, 'std': 23.986561774180547}\n",
            "cleaning {'average': 91.39024390243902, 'max': 483, 'min': 1, 'std': 133.53101001836941}\n",
            "maintenance {'average': 1291.0731707317073, 'max': 5662, 'min': 4, 'std': 1455.0359879925174}\n",
            "bonding {'average': 2474.0, 'max': 11616, 'min': 683, 'std': 2299.0895016474165}\n",
            "subbing {'average': 5893.609756097561, 'max': 20773, 'min': 1136, 'std': 4222.55737374153}\n",
            "anchor {'average': 6398.975609756098, 'max': 21392, 'min': 497, 'std': 4246.727684205259}\n",
            "\n",
            "Most common words predicted:\n",
            "priming\n",
            "starting\n",
            "pumping\n",
            "suction\n",
            "prime\n",
            "vacuum\n",
            "centrifugal\n",
            "flushing\n",
            "contained\n",
            "cleaning\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "khUmyje2aMzT"
      },
      "source": [
        "## Extending BERT - CPC Classifier\n",
        "\n",
        "A lot more can be done with the BERT trained model beyond synonym prediction. We can take the BERT outputs to do things such as:\n",
        "- Build classifiers for CPC codes (or anything else)\n",
        "- Tune a model on top of BERT ouputs to perform autocomplete\n",
        "- Perform semantic simialrity by training some type of siamese network on the BERT outputs\n",
        "\n",
        "Below we take the BERT outputs for 100 patents and build a tiny classifier to predict the first letter of the CPC code for a patent."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "MGDrVvIMaWK-"
      },
      "source": [
        "query = r'''\n",
        "  #standardSQL\n",
        "  SELECT DISTINCT\n",
        "    substr(cpc.code, 0, 1) as cpc_class,\n",
        "    res.abstract\n",
        "  FROM `patents-public-data.google_patents_research.publications` res,\n",
        "    UNNEST(cpc) as cpc\n",
        "    INNER JOIN `patents-public-data.patents.publications` pub ON \n",
        "      res.publication_number = pub.publication_number\n",
        "  WHERE \n",
        "    pub.publication_date >= 20000101 AND\n",
        "    res.country = 'United States' AND\n",
        "    cpc.first = True AND\n",
        "    RAND() < 0.1\n",
        "  LIMIT {}\n",
        "'''.format(200)\n",
        "\n",
        "df = bq_client.query(query).to_dataframe()\n",
        "df = df.sample(frac=1).reset_index(drop=True)\n",
        "\n",
        "cpc_classes = {\n",
        "    'A': 0, 'B': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'Y': 8}\n",
        "\n",
        "texts = df.abstract.tolist()\n",
        "classes = [cpc_classes[x] for x in df.cpc_class.tolist()]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WFHiw55EapKy"
      },
      "source": [
        "response, inputs, masked_text = bert_predictor.predict(texts)\n",
        "\n",
        "train_inputs = response['cls_token']\n",
        "train_labels = tf.convert_to_tensor(classes)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bt5NhcIkapPC",
        "executionInfo": {
          "status": "ok",
          "timestamp": 1605026477356,
          "user_tz": 300,
          "elapsed": 3345,
          "user": {
            "displayName": "Rob Srebrovic",
            "photoUrl": "",
            "userId": "06004353344935214283"
          }
        },
        "outputId": "0604e757-eb14-4029-aebe-6bdec542ea63"
      },
      "source": [
        "num_classes = len(cpc_classes)\n",
        "\n",
        "model = tf.keras.Sequential([\n",
        "  tf.keras.layers.Dense(128, activation='relu'),\n",
        "  tf.keras.layers.Dense(num_classes)\n",
        "])\n",
        "\n",
        "model.compile(\n",
        "    optimizer='adam',\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "history = model.fit(\n",
        "    x=train_inputs, \n",
        "    y=train_labels, \n",
        "    epochs=10, \n",
        "    validation_split=0.1)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch 1/10\n",
            "6/6 [==============================] - 2s 74ms/step - loss: 2.4930 - accuracy: 0.1880 - val_loss: 2.0529 - val_accuracy: 0.3500\n",
            "Epoch 2/10\n",
            "6/6 [==============================] - 0s 22ms/step - loss: 1.5419 - accuracy: 0.4371 - val_loss: 1.9152 - val_accuracy: 0.2500\n",
            "Epoch 3/10\n",
            "6/6 [==============================] - 0s 23ms/step - loss: 1.1766 - accuracy: 0.5592 - val_loss: 1.8375 - val_accuracy: 0.3500\n",
            "Epoch 4/10\n",
            "6/6 [==============================] - 0s 27ms/step - loss: 0.8675 - accuracy: 0.6735 - val_loss: 1.6896 - val_accuracy: 0.3000\n",
            "Epoch 5/10\n",
            "6/6 [==============================] - 0s 22ms/step - loss: 0.7060 - accuracy: 0.7969 - val_loss: 1.7911 - val_accuracy: 0.3500\n",
            "Epoch 6/10\n",
            "6/6 [==============================] - 0s 22ms/step - loss: 0.5098 - accuracy: 0.8422 - val_loss: 1.6971 - val_accuracy: 0.2500\n",
            "Epoch 7/10\n",
            "6/6 [==============================] - 0s 27ms/step - loss: 0.4247 - accuracy: 0.9095 - val_loss: 1.5988 - val_accuracy: 0.3000\n",
            "Epoch 8/10\n",
            "6/6 [==============================] - 0s 19ms/step - loss: 0.3788 - accuracy: 0.9378 - val_loss: 1.6432 - val_accuracy: 0.3500\n",
            "Epoch 9/10\n",
            "6/6 [==============================] - 0s 21ms/step - loss: 0.2781 - accuracy: 0.9606 - val_loss: 1.6566 - val_accuracy: 0.3500\n",
            "Epoch 10/10\n",
            "6/6 [==============================] - 0s 25ms/step - loss: 0.2773 - accuracy: 0.9677 - val_loss: 1.5322 - val_accuracy: 0.3500\n"
          ],
          "name": "stdout"
        }
      ]
    },
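    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In the run recorded above, the classifier reaches about 97% training accuracy but only 35% validation accuracy, so with this little data it is largely memorizing the training set. As a rough sanity check, and not part of the original workflow, the sketch below compares the best validation accuracy with a majority-class baseline. `majority_baseline_accuracy` is a hypothetical helper introduced here for illustration; the cell assumes `classes` and `history` from the cells above are still in scope."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {},
      "source": [
        "import collections\n",
        "\n",
        "\n",
        "def majority_baseline_accuracy(labels):\n",
        "  # Hypothetical helper (not part of the original notebook): accuracy of\n",
        "  # always predicting the most frequent label.\n",
        "  _, top_count = collections.Counter(labels).most_common(1)[0]\n",
        "  return top_count / len(labels)\n",
        "\n",
        "\n",
        "# Baseline over all labels; a useful classifier should clearly beat this.\n",
        "print('Majority-class baseline:', majority_baseline_accuracy(classes))\n",
        "# Keras records per-epoch metrics in history.history.\n",
        "print('Best validation accuracy:', max(history.history['val_accuracy']))"
      ],
      "execution_count": null,
      "outputs": []
    },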
    {
      "cell_type": "code",
      "metadata": {
        "id": "xSzuUWQcAvIR"
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}